import os
import argparse

import mne
import math
import time
import json

import numpy as np

from scipy.signal import butter, filtfilt

from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace

from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, matthews_corrcoef

start = time.time()  # wall-clock reference for the elapsed-time report at the end

parser = argparse.ArgumentParser(
    description="Cross-validated target vs non-target EEG epoch classification for one subject."
)
parser.add_argument('-s', default=None,
                    help="subject ID, e.g. 'sub-A' (default: sub-A)")
parser.add_argument('-c', default=None, type=int,
                    help="1-based class index whose runs are treated as target (default: 1)")
args = parser.parse_args()
print(args.s)
print(args.c)

# This script has no module docstring, so ``__doc__`` is None; only echo it when
# one exists instead of printing the literal "None".
if __doc__:
    print(__doc__)

##------------------------------------------------------------------------------

def apply_filter(data, b, a):
    """Filter ``data`` forward and backward with the IIR filter ``(b, a)``.

    Thin wrapper around ``scipy.signal.filtfilt`` so it can be passed to
    ``Raw.apply_function``. The forward-backward pass gives a zero-phase
    result. Filtering runs along the last axis, which is filtfilt's default.

    Parameters
    ----------
    data : array_like
        Signal(s) to filter.
    b, a : array_like
        Numerator / denominator coefficients of the IIR filter.

    Returns
    -------
    numpy.ndarray
        The filtered signal, same shape as ``data``.
    """
    return filtfilt(b, a, data)
##------------------------------------------------------------------------------

# Repository root: the parent of the directory that contains this script.
repository_base = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

base = os.path.join(repository_base, "eeg")           # root of the input recordings
save_base = os.path.join(repository_base, "results")  # where the score JSON files are written

# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(save_base, exist_ok=True)

subject = 'sub-A' if args.s is None else args.s

# Run numbers grouped by class. Row k holds the two runs of class k+1.
# With the `tasks` order used when loading, these are: row 0 = 'low', row 1 = 'mid',
# row 2 = 'high'.
fnum = np.array([[1,4],
                 [2,5],
                 [3,6]])

# Trigger code of each class; trig_id[k] pairs with row k of fnum.
trig_id = [2,8,32]

test_class = 1
if args.c is not None:
    test_class = args.c

# test_class is 1-based and is used as fnum[test_class-1] / trig_id[test_class-1].
# A value of 0 or below would silently wrap to another class through negative
# indexing, and the results would then be saved under the wrong key.
if not 1 <= test_class <= len(fnum):
    parser.error("-c must be between 1 and %d, got %d" % (len(fnum), test_class))

# --- Filtering -------------------------------------------------------------
# Sampling rate (Hz) used to design the filter.
# NOTE(review): this is hard-coded rather than read from the recordings;
# confirm it matches raw.info['sfreq'] for every run.
Fs = 1000
fc = [1, 40]     # band-pass edges in Hz
resample = None  # new sampling rate for Raw.resample, or None to keep the original

nyquist = Fs / 2
b, a = butter(N=2, Wn=np.array(fc) / nyquist, btype='bandpass', output='ba')

# --- Epoching --------------------------------------------------------------
tmin = -0.1             # epoch start (s) relative to each event
tmax = 0.5              # epoch end (s) relative to each event
baseline = (-0.05, 0)   # baseline-correction window (s)

# Task label of each run, in the same order as fnum.ravel().
tasks = ['low', 'low', 'mid', 'mid', 'high', 'high']

# Peak-to-peak rejection thresholds per channel type (MNE units: volts).
reject = {'eeg': 100e-6, 'eog': 500e-6}

# Synthetic event codes given to target / non-target epochs after merging.
event_id = {'target': -100, 'non_target': -500}

##------------------------------------------------------------------------------

def _load_filtered_run(fname):
    """Read one BrainVision run (preloaded, hEOG/vEOG typed as EOG) and band-pass every channel."""
    raw = mne.io.read_raw_brainvision(fname, preload=True, eog=('hEOG', 'vEOG'))
    # apply_function forwards b/a to apply_filter and returns the same Raw object.
    return raw.apply_function(apply_filter, channel_wise=True, b=b, a=a)


def _concatenate(raws):
    """Join runs in load order; a single run is returned unchanged."""
    return raws[0] if len(raws) == 1 else mne.concatenate_raws(raws)


if len(tasks) != fnum.size:
    raise ValueError("tasks has %d entries but fnum lists %d runs; they must pair one-to-one"
                     % (len(tasks), fnum.size))

t = []   # run numbers loaded into ``target``
nt = []  # run numbers loaded into ``non_target``
target_raws = []
non_target_raws = []

# Runs of the tested class are "target"; every other run is "non-target".
target_runs = fnum[test_class - 1]
for task, run in zip(tasks, fnum.ravel()):
    fname = os.path.join(base, subject, "eeg", "%s_task-%s_run-%d_eeg.vhdr" % (subject, task, run))
    print(fname)
    raw = _load_filtered_run(fname)
    if run in target_runs:
        target_raws.append(raw)
        t.append(run)
    else:
        non_target_raws.append(raw)
        nt.append(run)

# Fail clearly here instead of with an AttributeError on a list further down.
if not target_raws or not non_target_raws:
    raise RuntimeError("need at least one target and one non-target run (target runs %s, non-target runs %s)"
                       % (t, nt))

target = _concatenate(target_raws)
non_target = _concatenate(non_target_raws)

if resample is not None:
    target.resample(resample)
    non_target.resample(resample)
        
# events_from_annotations returns (events, event_id mapping); only the events array is used.
target_eve, _ = mne.events_from_annotations(target)
non_target_eve, _ = mne.events_from_annotations(non_target)

# Recode the tested class's trigger to the synthetic target / non-target code.
# replace_events=True drops the original trigger rows.
tested_trigger = [trig_id[test_class - 1]]
target_eve = mne.merge_events(target_eve, tested_trigger, event_id['target'], replace_events=True)
non_target_eve = mne.merge_events(non_target_eve, tested_trigger, event_id['non_target'], replace_events=True)

# Both epoch sets share the window, baseline and rejection settings.
epoch_kwargs = dict(tmin=tmin, tmax=tmax, baseline=baseline, reject=reject, preload=True)
target_epochs = mne.Epochs(target, events=target_eve, event_id=event_id['target'], **epoch_kwargs)
non_target_epochs = mne.Epochs(non_target, events=non_target_eve, event_id=event_id['non_target'], **epoch_kwargs)

# Stack both classes, then keep only EEG channels.
# EOG has already served its purpose in the rejection above.
epochs = mne.concatenate_epochs([target_epochs, non_target_epochs]).pick_types(eeg=True, eog=False)

# Pipeline:
#   1. Xdawn covariances (nfilter=3);
#   2. projection to the Riemannian tangent space;
#   3. L1-regularised logistic regression.
# ``multi_class`` is not passed for three reasons: the problem is binary
# (target vs non-target codes), liblinear is one-vs-rest anyway, and the
# parameter is deprecated since scikit-learn 1.5.
clf = make_pipeline(XdawnCovariances(3),
                    TangentSpace(metric='riemann'),
                    LogisticRegression(penalty='l1', solver='liblinear'))

epochs_data = epochs.get_data()   # (n_epochs, n_channels, n_times)
labels = epochs.events[:, -1]     # event_id['target'] or event_id['non_target'] per epoch

cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# Out-of-fold predictions. Each epoch falls in exactly one test fold, so every
# slot is filled. The labels' dtype keeps predictions as integer event codes
# rather than floats.
preds = np.empty(len(labels), dtype=labels.dtype)
for train, test in cv.split(epochs_data, labels):
    clf.fit(epochs_data[train], labels[train])
    preds[test] = clf.predict(epochs_data[test])

# Pin the label order so target_names line up with the event codes.
# Otherwise the mapping depends on the numeric sort order of those codes.
report = classification_report(labels, preds,
                               labels=[event_id['non_target'], event_id['target']],
                               target_names=['non-target', 'target'],
                               output_dict=True)
print(report)

mcc = matthews_corrcoef(labels, preds)
print('MCC Score : ', round(mcc, 2), '\n')

# Read the clock once so minutes and seconds describe the same instant.
minutes, seconds = divmod(time.time() - start, 60)
print('Time : ', math.floor(minutes), 'm', math.floor(seconds), 's\n')

# Per-subject results file holding one entry per tested class.
# Runs for different classes accumulate in the same file.
scores_path = os.path.join(save_base, "%s_classification_scores.json" % subject)

try:
    with open(scores_path, 'r', encoding='utf-8') as f:
        data_json = json.load(f)
except FileNotFoundError:
    data_json = dict()

# JSON object keys are always strings, so a previously saved class comes back
# as "1". Store under str(test_class) so that entry is overwritten. An int key 1
# would instead be serialised by json.dump as a second, duplicate "1" key.
data_json[str(test_class)] = {
    'report': report,
    'mcc': float(mcc),
    'id': event_id,
    'labels': labels.tolist(),
    'preds': preds.tolist(),
}

with open(scores_path, 'w', encoding='utf-8') as f:
    json.dump(data_json, f, indent=4)
